{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada96372-65af-4b7a-ad62-ef74352328a3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the project's files to the python path\n",
    "# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script\n",
    "file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook\n",
    "sys.path.append(file_path)\n",
    "\n",
    "import torch\n",
    "from src.datasets.s3dis import CLASS_NAMES, CLASS_COLORS, STUFF_CLASSES\n",
    "from src.datasets.s3dis import S3DIS_NUM_CLASSES as NUM_CLASSES\n",
    "from src.transforms import *\n",
    "from src.data import NAG"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "585172ca-bb12-40be-85ff-5e12f5c2a06e",
   "metadata": {},
   "source": [
    "The main data structures of this project are `Data` and `NAG`.\n",
    "\n",
    "`Data` stores a single-level graph.\n",
    "It inherits from `torch_geometric`'s `Data` and behaves similarly (see the [official documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) for more on this).\n",
    "Important specificities of our `Data` object are:\n",
    "- `Data.super_index` stores the parent's index for each node in `Data`\n",
    "- `Data.sub` holds a `Cluster` object indicating the children of each node in `Data`\n",
    "- `Data.to_trimmed()` works like `torch_geometric`'s `Data.coalesce()` with the additional constraint that (i,j) and (j,i) edges are considered duplicates\n",
    "- `Data.save()` and `Data.load()` allow optimized, memory-friendly I/O operations\n",
    "- `Data.select()` indexes the nodes à la numpy\n",
    "\n",
    "`NAG` (Nested Acyclic Graph) stores the hierarchical partition in the form of a list of `Data` objects.\n",
    "Important specificities of our `NAG` object are:\n",
    "- `Data` objects held by a `NAG` are indexed using **absolute** indices. This means that if you load level 1 but not level 0 (the atomic level), accessing `NAG[0]` will raise an error\n",
    "- The `NAG.start_i_level` parameter indicates the first level held by the `NAG`\n",
    "- `NAG[i]` returns a `Data` object holding the partition level `i`\n",
    "- `NAG.get_super_index()` returns the index mapping nodes from any level `i` to `j` with `i<j`\n",
    "- `NAG.get_sampling()` produces indices for sampling the superpoints with certain constraints\n",
    "- `NAG.save()` and `NAG.load()` allow optimized, memory-friendly I/O operations\n",
    "- `NAG.select()` indexes the nodes of a specified partition level à la numpy and updates the rest of the `NAG` structure accordingly"
   ]
  },
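  {
   "cell_type": "markdown",
   "id": "toy-super-index-sketch-md",
   "metadata": {},
   "source": [
    "**Sketch** - to make the `super_index` convention more concrete, the toy example below uses plain `torch` tensors rather than the project's `Data` and `NAG` classes. The tensors `super_index_0` and `super_index_1` hold made-up values, and composing them by indexing is only an assumption about what a level-`i` to level-`j` mapping looks like conceptually, not the actual implementation of `NAG.get_super_index()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-super-index-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy hierarchy (made-up values, not read from a real NAG):\n",
    "# 6 level-0 points grouped into 3 level-1 superpoints, themselves grouped\n",
    "# into 2 level-2 superpoints\n",
    "super_index_0 = torch.tensor([0, 0, 1, 1, 2, 2])  # level-0 -> level-1 parent\n",
    "super_index_1 = torch.tensor([0, 0, 1])  # level-1 -> level-2 parent\n",
    "\n",
    "# Composing the two mappings by indexing gives, for each level-0 point,\n",
    "# the index of its level-2 ancestor\n",
    "super_index_0_to_2 = super_index_1[super_index_0]\n",
    "print(super_index_0_to_2)"
   ]
  },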
  {
   "cell_type": "markdown",
   "id": "eb14cf3b-e468-4ae3-bd85-cda6d8437103",
   "metadata": {},
   "source": [
    "## Loading a NAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca481f05-82a0-4179-8c0b-0d7320f303b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "nag = NAG.load('demo_nag_v3.h5', non_fp_to_long=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb4ab507-e83d-4d16-aa6f-180b0e95e196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print general info about the NAG\n",
    "print(nag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f67bdb0c-5fe5-4048-bd4a-4f029b1906c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loop over the partition levels and print general info about each Data\n",
    "for i_level, data in enumerate(nag):\n",
    "    print(f\"Level-{i_level}:\\n{data}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c1ab511-dc3a-4750-bde8-5c6fcae7a516",
   "metadata": {},
   "source": [
    "## Visualizing a NAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae5bec5-01d2-41e3-90ec-244a61cac5d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the hierarchical partition\n",
    "nag.show( \n",
    "    class_names=CLASS_NAMES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    stuff_classes=STUFF_CLASSES,\n",
    "    num_classes=NUM_CLASSES,\n",
    "    max_points=100000,\n",
    "    centroids=True,\n",
    "    h_edge=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a653dd-291c-43c6-8c83-78ebff2d00b9",
   "metadata": {},
   "source": [
    "## Selecting a portion of the hierarchical partition\n",
    "The NAG structure can be subselected using `nag.select()`.\n",
    "\n",
    "This function expects an `int` specifying the partition level to select from, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.\n",
    "This index/mask describes which nodes to select at the specified level.\n",
    "\n",
    "The output NAG will only contain the selected nodes along with their children, parents, and edges."
   ]
  },
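  {
   "cell_type": "markdown",
   "id": "toy-index-vs-mask-sketch-md",
   "metadata": {},
   "source": [
    "**Sketch** - the difference between an index and a mask can be illustrated with plain `torch` tensors, independently of `NAG`. The toy `pos` tensor below holds made-up values, and the example only shows that a boolean mask and the integer indices returned by `torch.where` designate the same points. How `nag.select()` handles each input type internally is not shown here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-index-vs-mask-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy point cloud: 5 points on the x-axis (made-up values)\n",
    "pos = torch.tensor([\n",
    "    [0.0, 0.0, 0.0],\n",
    "    [0.5, 0.0, 0.0],\n",
    "    [1.5, 0.0, 0.0],\n",
    "    [2.0, 0.0, 0.0],\n",
    "    [0.2, 0.0, 0.0]])\n",
    "center = torch.zeros(3)\n",
    "radius = 1\n",
    "\n",
    "# Boolean mask: one True/False entry per point\n",
    "mask = torch.linalg.norm(pos - center, dim=1) < radius\n",
    "\n",
    "# Integer index: positions of the True entries in the mask\n",
    "idx = torch.where(mask)[0]\n",
    "\n",
    "# Both forms designate the same subset of points\n",
    "print(mask)\n",
    "print(idx)\n",
    "assert torch.equal(pos[mask], pos[idx])"
   ]
  },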
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d86de234-3da8-4d10-a373-7dce47c514da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pick a center and radius for the spherical sample\n",
    "center = nag[0].pos.mean(dim=0)\n",
    "radius = 1\n",
    "\n",
    "# Find the level-0 nodes (i.e. points) lying within the sphere. Despite its\n",
    "# name, `mask` holds the integer indices returned by torch.where rather\n",
    "# than booleans; it is used below for indexing the NAG structure\n",
    "mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]\n",
    "\n",
    "# Subselect the hierarchical partition based on the level-0 indices\n",
    "nag_visu = nag.select(0, mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef97ce1-527d-418d-b273-839c721ede6e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample\n",
    "nag_visu.show(\n",
    "    class_names=CLASS_NAMES,\n",
    "    class_colors=CLASS_COLORS, \n",
    "    stuff_classes=STUFF_CLASSES,\n",
    "    num_classes=NUM_CLASSES,\n",
    "    max_points=100000,\n",
    "    centroids=True,\n",
    "    h_edge=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2a0ff0-ea84-47d1-8fdf-30b425955a36",
   "metadata": {},
   "source": [
    "**Tip** - the above example illustrates the `nag.select()` method, which is not limited to spherical samples. However, since visualizing a spherical sample of a large cloud is a common operation, the `show()` function lets you do the same directly by specifying a `radius` and a `center`. See the `show()` documentation for more details."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spt_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
